import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import pandas as pd
import pymc3 as pm
import toolz as tz
import itertools
import arviz as az
%matplotlib inline
# --- Generate data ------------------------------------------------------------
# Simulation settings, one dict per synthetic group, consumed by points():
#   n              number of points in the group
#   x_mean, x_sd   x is drawn as x_mean + x_sd * N(0, 1)
#   y_int, y_slope, y_sd
#                  y is drawn as y_int + y_slope * x + y_sd * N(0, 1)
GROUP_PARAMS = [
    {'label': 'A', 'n': 54, 'x_mean': 2, 'x_sd': 0.8,
     'y_int': 0, 'y_slope': 1.5, 'y_sd': 1},
    {'label': 'B', 'n': 18, 'x_mean': 12, 'x_sd': 1.2,
     'y_int': -7, 'y_slope': 1.0, 'y_sd': 1},
    {'label': 'C', 'n': 6, 'x_mean': 8, 'x_sd': 1.0,
     'y_int': -15, 'y_slope': 1.2, 'y_sd': 1},
]
def points(params, rng=None):
    """Simulate one group of (x, y) points from a noisy linear model.

    x is drawn as ``x_mean + x_sd * N(0, 1)`` and y as
    ``y_int + y_slope * x + y_sd * N(0, 1)``, with ``n`` samples each.

    Parameters
    ----------
    params : mapping
        Must contain ``label``, ``n``, ``x_mean``, ``x_sd``, ``y_int``,
        ``y_slope`` and ``y_sd``; any other keys are ignored.
    rng : numpy.random.Generator, optional
        Source of the standard-normal draws. When omitted, the legacy global
        ``np.random`` state is used (the original behaviour), which is only
        reproducible if that global state has been seeded.

    Returns
    -------
    dict
        ``{'x': ndarray, 'y': ndarray, 'label': label}``.

    Raises
    ------
    KeyError
        If a required key is missing from *params*.
    """
    # Plain indexing, in the same order as before, so the first missing key
    # raises KeyError just like the previous toolz.get lookup.
    label = params['label']
    n = params['n']
    x_mean = params['x_mean']
    x_sd = params['x_sd']
    y_int = params['y_int']
    y_slope = params['y_slope']
    y_sd = params['y_sd']

    # Draw x noise first, then y noise, so the global-RNG path consumes random
    # numbers in exactly the original order.
    standard_normal = np.random.randn if rng is None else rng.standard_normal
    x = x_mean + x_sd * standard_normal(n)
    y = y_int + y_slope * x + y_sd * standard_normal(n)
    return {'x': x, 'y': y, 'label': label}
def dataframe(group_params):
    """Simulate every group spec with points() and stack the results.

    Each spec becomes a frame with ``x``, ``y`` and ``label`` columns; the
    frames are concatenated in input order and re-indexed from 0.
    """
    frames = []
    for spec in group_params:
        frames.append(pd.DataFrame(points(spec)))
    return pd.concat(frames, ignore_index=True)
# Simulate all groups, write them to CSV (without the index) and read the file
# back, so every later cell works on the on-disk copy held in `data`.
# NOTE(review): this overwrites group_sample_data.csv with newly simulated data
# on every run unless the global NumPy RNG is seeded somewhere not shown here.
dataframe(GROUP_PARAMS).to_csv('group_sample_data.csv', index=False)
data = pd.read_csv('group_sample_data.csv')
data.head()
# --- Plotting helpers and shared styling ---------------------------------------
# Palette for the three groups, in the order colours are handed out.
COLORS = [
    'rgb(251,128,114)',
    'rgb(128,177,211)',
    'rgb(179,222,105)',
]
# Square, white, axis-free canvas shared by every figure.
LAYOUT = go.Layout(
    width=800,
    height=800,
    plot_bgcolor='white',
    xaxis={'showgrid': False, 'zeroline': False, 'showticklabels': False},
    yaxis={'showgrid': False, 'zeroline': False, 'showticklabels': False},
)
def group_scatter(group, label, color):
    """Return a marker-only scatter trace for one group's points.

    *group* must provide ``'x'`` and ``'y'`` columns; *label* becomes the
    legend name and *color* the marker colour.
    """
    marker_style = {'size': 8, 'color': color}
    return go.Scatter(
        name=label,
        x=group['x'],
        y=group['y'],
        mode='markers',
        marker=marker_style,
    )
def all_scatters(data, colors):
    """Return one scatter trace per ``label`` group of *data*.

    Groups are visited in ``groupby`` order and take colours from *colors*
    in turn, wrapping around if there are more groups than colours.
    """
    palette = itertools.cycle(colors)
    traces = []
    for label, group in data.groupby('label'):
        traces.append(group_scatter(group, label, next(palette)))
    return traces
def line_points(xmin, xmax, intercept, slope):
    """Return the endpoints of one or more lines ``y = intercept + slope * x``.

    Returns ``(x_ends, y_ends)``:

    - ``x_ends`` is ``array([xmin, xmax])``.
    - For length-k array inputs, ``y_ends`` has shape (k, 2): one row of
      ``[y(xmin), y(xmax)]`` per line.
    - For scalar inputs, ``y_ends`` is just the 1-D pair, shape (2,).
    """
    x_ends = np.array([xmin, xmax])
    # Rows are [y at xmin, y at xmax]; transposing puts one line per row.
    y_ends = np.array([intercept + slope * x_end for x_end in (xmin, xmax)])
    return x_ends, np.transpose(y_ends)
def line_trace(x, y, color):
    """Return a thin (width 1), 20%-opaque line trace with no legend entry."""
    stroke = {'width': 1, 'color': color}
    return go.Scatter(
        x=x,
        y=y,
        mode='lines',
        line=stroke,
        opacity=0.2,
        showlegend=False,
    )
def all_lines(xmin, xmax, intercept, slope, color='rgb(180,180,180)'):
    """Build one line trace per (intercept, slope) pair, spanning [xmin, xmax].

    Parameters
    ----------
    xmin, xmax : float
        x range every line is drawn over.
    intercept, slope : float or 1-D array
        Either scalars (a single line) or equal-length arrays, one line per
        element (e.g. posterior draws).
    color : str
        Line colour; defaults to light grey.

    Returns
    -------
    list
        Traces built by line_trace, one per line.
    """
    x, ys = line_points(xmin, xmax, intercept, slope)
    # line_points yields shape (k, 2) for array inputs but a bare (2,) pair for
    # scalars. Without this, the scalar case would iterate over the two endpoint
    # values and build two bogus traces instead of one line.
    ys = np.atleast_2d(ys)
    return [line_trace(x, y, color) for y in ys]
def group_lines(trace, data, colors=COLORS, step=10):
    """Posterior regression lines for each group, drawn over that group's x range.

    Parameters
    ----------
    trace :
        Posterior samples indexable by name. ``trace['intercept']`` and
        ``trace['slope']`` must have shape (draws, n_groups), or (draws,) for
        a single group. Their columns are assumed to follow
        ``data.groupby('label')`` order.
    data : pandas.DataFrame
        Observed points with ``label`` and ``x`` columns.
    colors : sequence of str
        One colour per group, cycled if there are more groups than colours
        (the same policy as all_scatters).
    step : int
        Keep every ``step``-th draw (default 10) to limit the number of traces.

    Returns
    -------
    list
        Line traces, grouped by label in ``groupby`` order.

    Raises
    ------
    ValueError
        If the number of intercept/slope columns differs from the number of
        label groups in *data*.
    """
    def _per_group(samples):
        # (draws, n_groups) -> (n_groups, draws); a 1-D (draws,) array -> (1, draws).
        return np.atleast_2d(np.asarray(samples).T)

    intercepts = _per_group(trace['intercept'])
    slopes = _per_group(trace['slope'])
    # One [min, max] row per label, in groupby (sorted-label) order.
    x_ranges = data.groupby('label')['x'].agg(['min', 'max']).values
    if not (len(x_ranges) == len(intercepts) == len(slopes)):
        raise ValueError(
            f"trace has {len(intercepts)} intercept and {len(slopes)} slope "
            f"columns but data has {len(x_ranges)} label groups"
        )

    palette = itertools.cycle(colors)
    lines = []
    for (x_min, x_max), group_int, group_slp in zip(x_ranges, intercepts, slopes):
        lines += all_lines(
            x_min, x_max, group_int[::step], group_slp[::step], color=next(palette)
        )
    return lines
def figure(traces, layout=LAYOUT):
    """Wrap *traces* in a plotly Figure styled with *layout* (LAYOUT by default)."""
    fig = go.Figure(layout=layout, data=traces)
    return fig
# Scatter plot of the raw data, one colour per label group.
f = figure(all_scatters(data, COLORS))
f.show()
# Complete pooling: a single regression line (one intercept, one slope, one
# noise scale) is fitted to all points; the group labels are not used.
with pm.Model() as model_cp:
    # Intercept prior is centred on the overall mean of the observed y.
    intercept = pm.Normal('intercept', mu=data['y'].mean(), sigma=10)
    slope = pm.Normal('slope', mu=0, sigma=1)
    # Expected y for every observed x, stored in the trace under 'mu_y'.
    mu = pm.Deterministic('mu_y', intercept + slope * data['x'].values)
    sigma = pm.HalfNormal('sigma_y', sigma=10)
    pm.Normal('y', mu=mu, sigma=sigma, observed=data['y'].values)
pm.model_to_graphviz(model_cp)
with model_cp:
    # NOTE(review): no random_seed is passed, so the draws (and everything
    # derived from them below) differ between runs.
    trace_cp = pm.sample()
az.plot_trace(trace_cp, var_names=['intercept', 'slope']);
az.plot_posterior(trace_cp, var_names=['intercept', 'slope']);
# Overlay every 10th posterior draw as a line (all_lines' default grey)
# spanning the full observed x range.
lines = all_lines(data['x'].min(), data['x'].max(), trace_cp['intercept'][::10], trace_cp['slope'][::10])
traces = all_scatters(data, COLORS) + lines
fig = figure(traces)
fig.show()
# No pooling: every group gets its own intercept, slope and noise scale
# (shape=3), and no parameter is shared between groups.
# NOTE(review): shape=3 hard-codes the number of labels in `data`; confirm it
# still matches if GROUP_PARAMS changes.
with pm.Model() as model_up:
    intercept = pm.Normal('intercept', mu=0, sigma=10, shape=3)
    slope = pm.Normal('slope', mu=0, sigma=1, shape=3)
    sigma = pm.HalfNormal('sigma', sigma=5, shape=3)
    # One likelihood per group. Index i follows groupby's sorted-label order,
    # which is also the column order group_lines reads from the trace.
    for i, (label, group) in enumerate(data.groupby('label')):
        mu = intercept[i] + slope[i] * group['x'].values
        pm.Normal(f'y_{label}', mu=mu, sigma=sigma[i], observed=group['y'].values)
pm.model_to_graphviz(model_up)
with model_up:
    # NOTE(review): unseeded sampling; results vary between runs.
    trace_up = pm.sample()
az.plot_posterior(trace_up, var_names=['slope']);
# Data overlaid with each group's posterior lines, drawn over that group's x range.
fig = figure(all_scatters(data, COLORS) + group_lines(trace_up, data))
fig.show()
# Partial pooling of the slopes only: the three group slopes are drawn around a
# shared, learned mean `mu_slope`, but their spread is fixed at 1 rather than
# learned. Intercepts and noise scales have no shared hyperparameters.
with pm.Model() as model_pp:
    # A shared slope hyperparameter
    mu_slope = pm.Normal('mu_slope', mu=0, sigma=1)
    intercept = pm.Normal('intercept', mu=0, sigma=10, shape=3)
    slope = pm.Normal('slope', mu=mu_slope, sigma=1, shape=3)
    sigma = pm.HalfNormal('sigma', sigma=5, shape=3)
    # One likelihood per group; index i follows groupby's sorted-label order.
    for i, (label, group) in enumerate(data.groupby('label')):
        mu = intercept[i] + slope[i] * group['x'].values
        pm.Normal(f'y_{label}', mu=mu, sigma=sigma[i], observed=group['y'].values)
pm.model_to_graphviz(model_pp)
with model_pp:
    # NOTE(review): unseeded sampling; results vary between runs.
    trace_pp = pm.sample()
az.plot_posterior(trace_pp, var_names=['mu_slope', 'slope']);
fig = figure(all_scatters(data, COLORS) + group_lines(trace_pp, data))
fig.show()
# Posterior draws of the third group's slope (column 2 in groupby label order,
# i.e. label 'C' for the data generated above).
slope_c = trace_pp['slope'][:, 2]
# First ten draws, for a quick look.
slope_c[:10]
# Fraction of draws with a positive slope: a Monte Carlo estimate of P(slope > 0).
(slope_c > 0.0).mean()
# Estimated posterior probability that the slope lies strictly between 1 and 1.5.
((slope_c > 1) & (slope_c < 1.5)).mean()